Assess parameter recovery results for PH with C model

Author

Milena Musial

Published

January 31, 2024

1 Setup

# Attach every package used in this report.
# `require()` merely returns FALSE (with a warning) when a package is missing,
# so the result is checked explicitly and the script stops right away instead
# of failing later with an obscure "could not find function" error.
# The former `rm(list = ls())` was dropped: it wipes the caller's workspace
# when the script is run interactively, and nothing below relies on an empty
# environment.
libs <- c("rstan", "gdata", "bayesplot", "stringr", "dplyr", "ggplot2",
          "PerformanceAnalytics")
loaded <- vapply(libs, require, logical(1), character.only = TRUE)
if (!all(loaded)) {
  stop("Missing required package(s): ",
       paste(libs[!loaded], collapse = ", "),
       call. = FALSE)
}
# Print the named logical vector (one TRUE per attached package)
loaded
Loading required package: rstan
Loading required package: StanHeaders

rstan version 2.26.22 (Stan version 2.26.1)
For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)
For within-chain threading using `reduce_sum()` or `map_rect()` Stan functions,
change `threads_per_chain` option:
rstan_options(threads_per_chain = 1)
Loading required package: gdata

Attaching package: 'gdata'
The following object is masked from 'package:stats':

    nobs
The following object is masked from 'package:utils':

    object.size
The following object is masked from 'package:base':

    startsWith
Loading required package: bayesplot
This is bayesplot version 1.11.1
- Online documentation and vignettes at mc-stan.org/bayesplot
- bayesplot theme set to bayesplot::theme_default()
   * Does _not_ affect other ggplot2 plots
   * See ?bayesplot_theme_set for details on theme setting
Loading required package: stringr
Loading required package: dplyr

Attaching package: 'dplyr'
The following objects are masked from 'package:gdata':

    combine, first, last, starts_with
The following objects are masked from 'package:stats':

    filter, lag
The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union
Loading required package: ggplot2
Loading required package: PerformanceAnalytics
Loading required package: xts
Loading required package: zoo

Attaching package: 'zoo'
The following objects are masked from 'package:base':

    as.Date, as.Date.numeric

######################### Warning from 'xts' package ##########################
#                                                                             #
# The dplyr lag() function breaks how base R's lag() function is supposed to  #
# work, which breaks lag(my_xts). Calls to lag(my_xts) that you type or       #
# source() into this session won't work correctly.                            #
#                                                                             #
# Use stats::lag() to make sure you're not using dplyr::lag(), or you can add #
# conflictRules('dplyr', exclude = 'lag') to your .Rprofile to stop           #
# dplyr from breaking base R's lag() function.                                #
#                                                                             #
# Code in packages is not affected. It's protected by R's namespace mechanism #
# Set `options(xts.warn_dplyr_breaks_lag = FALSE)` to suppress this warning.  #
#                                                                             #
###############################################################################

Attaching package: 'xts'
The following objects are masked from 'package:dplyr':

    first, last
The following objects are masked from 'package:gdata':

    first, last

Attaching package: 'PerformanceAnalytics'
The following object is masked from 'package:graphics':

    legend
               rstan                gdata            bayesplot 
                TRUE                 TRUE                 TRUE 
             stringr                dplyr              ggplot2 
                TRUE                 TRUE                 TRUE 
PerformanceAnalytics 
                TRUE 
# Project directories
datapath <- "/fast/work/groups/ag_schlagenhauf/B01_FP1_WP2/WP2_ILT_CODE/02_Behav_and_Comp_Modeling/"
out_path <- "/fast/work/groups/ag_schlagenhauf/B01_FP1_WP2/WP2_ILT_CODE/02_Behav_and_Comp_Modeling/Output"
behavpath <- "/fast/work/groups/ag_schlagenhauf/B01_FP1_WP2/ILT_DATA"

# File names of the three Stan outputs used below:
#   orig_file     - fit whose parameters served as "true" input for simulation
#   sim_file      - simulation output (y_pred and transformed parameters)
#   recovery_file - fit of the model to the simulated data
orig_file <- "fit_n58_2024-05-07_bandit2arm_delta_PH_withC_DU_estimation1_delta0.999_stepsize0.1_treedepth12.rds"
sim_file <- "sim_2024-05-07_bandit2arm_delta_PH_withC_sim_n58.rds"
recovery_file <- "recovery_2024-05-10_bandit2arm_delta_PH_withC_n58.rds"

# Read the Stan model outputs; simulation and recovery results live in the
# "Parameter_Recovery" sub-folder of the output directory
orig_fit <- readRDS(file.path(out_path, orig_file))
sim_fit <- readRDS(file.path(out_path, "Parameter_Recovery", sim_file))
recovery_fit <- readRDS(file.path(out_path, "Parameter_Recovery", recovery_file))

# bayesplot colour scheme for all recovery plots
color_scheme_set("mix-blue-pink")
# Collect true (simulation input) and recovered parameter values

# Helper: posterior means of all elements of `par` in the stanfit `fit`,
# returned as a plain (unnamed) vector
posterior_mean <- function(fit, par) {
  as.vector(summary(fit, pars = par)$summary[, "mean"])
}

# Helper: posterior draws of `par` from the recovery fit, as a matrix
recovery_draws <- function(par) {
  as.matrix(recovery_fit, pars = par)
}

## True values -----------------------------------------------------------

### posterior means of the original fit (these were the simulation input)
true_mu_pr <- posterior_mean(orig_fit, "mu_pr")
true_sigma <- posterior_mean(orig_fit, "sigma")

true_A_pr <- posterior_mean(orig_fit, "A_pr")
true_tau_pr <- posterior_mean(orig_fit, "tau_pr")
true_gamma_pr <- posterior_mean(orig_fit, "gamma_pr")
true_C_pr <- posterior_mean(orig_fit, "C_pr")

### transformed parameters saved during simulation (first draw is used)
sim_posterior <- extract(sim_fit)

true_A <- as.vector(sim_posterior$A[1, ])
true_tau <- as.vector(sim_posterior$tau[1, ])
true_gamma <- as.vector(sim_posterior$gamma[1, ])
true_C <- as.vector(sim_posterior$C[1, ])

true_mu_A <- as.vector(sim_posterior$mu_A[1])
true_mu_tau <- as.vector(sim_posterior$mu_tau[1])
true_mu_gamma <- as.vector(sim_posterior$mu_gamma[1])
true_mu_C <- as.vector(sim_posterior$mu_C[1])

## Recovered values (model fitted to the simulated data) -----------------

### group-level raw parameters
recovered_mu_pr <- recovery_draws("mu_pr")
recovered_sigma <- recovery_draws("sigma")

### individual raw parameters
recovered_A_pr <- recovery_draws("A_pr")
recovered_tau_pr <- recovery_draws("tau_pr")
recovered_gamma_pr <- recovery_draws("gamma_pr")
recovered_C_pr <- recovery_draws("C_pr")

### individual transformed parameters: full draws ...
recovered_A <- recovery_draws("A")
recovered_tau <- recovery_draws("tau")
recovered_gamma <- recovery_draws("gamma")
recovered_C <- recovery_draws("C")

### ... and their posterior means
recovered_A_mean <- posterior_mean(recovery_fit, "A")
recovered_tau_mean <- posterior_mean(recovery_fit, "tau")
recovered_gamma_mean <- posterior_mean(recovery_fit, "gamma")
recovered_C_mean <- posterior_mean(recovery_fit, "C")

### group-level transformed parameters
recovered_mu_A <- recovery_draws("mu_A")
recovered_mu_tau <- recovery_draws("mu_tau")
recovered_mu_gamma <- recovery_draws("mu_gamma")
recovered_mu_C <- recovery_draws("mu_C")


2 Recovery plots

# Helper: assign consecutive parameters to facets of `batch_size` each.
# Returns 1 for the first `batch_size` elements of `true`, 2 for the next
# ones, and so on; the last batch simply holds the remainder. The result
# always has length(true), so it no longer has to be adapted by hand when the
# sample size changes (for 116 parameters it equals the previously hard-coded
# c(rep(1,10), ..., rep(11,10), rep(12,6))).
recovery_batches <- function(true, batch_size = 10) {
  (seq_along(true) - 1) %/% batch_size + 1
}

# Helper: recovery interval plot (50% inner / 95% outer intervals) for
# individual-level parameters, split into one-column facets of 10 parameters.
plot_individual_recovery <- function(recovered, true) {
  mcmc_recover_intervals(recovered, true, prob = 0.5, prob_outer = 0.95,
                         batch = recovery_batches(true),
                         facet_args = list(ncol = 1))
}

# mus (raw and transformed)
mcmc_recover_intervals(recovered_mu_pr, true_mu_pr, prob = 0.5, prob_outer = 0.95)

mcmc_recover_intervals(recovered_mu_A, true_mu_A, prob = 0.5, prob_outer = 0.95)

mcmc_recover_intervals(recovered_mu_tau, true_mu_tau, prob = 0.5, prob_outer = 0.95)

mcmc_recover_intervals(recovered_mu_gamma, true_mu_gamma, prob = 0.5, prob_outer = 0.95)

mcmc_recover_intervals(recovered_mu_C, true_mu_C, prob = 0.5, prob_outer = 0.95)

# sigma
mcmc_recover_intervals(recovered_sigma, true_sigma, prob = 0.5, prob_outer = 0.95)

# individual distances from mu
plot_individual_recovery(recovered_A_pr, true_A_pr)

plot_individual_recovery(recovered_tau_pr, true_tau_pr)

plot_individual_recovery(recovered_gamma_pr, true_gamma_pr)

plot_individual_recovery(recovered_C_pr, true_C_pr)

# transformed individual parameters
plot_individual_recovery(recovered_A, true_A)

plot_individual_recovery(recovered_tau, true_tau)

plot_individual_recovery(recovered_gamma, true_gamma)

plot_individual_recovery(recovered_C, true_C)

3 Correlation between true and recovered individual parameters

# One row per individual parameter: true values (from the simulation) next to
# the posterior means recovered from the simulated data
param_df <- data.frame(
  true_A = true_A,
  true_tau = true_tau,
  true_gamma = true_gamma,
  true_C = true_C,
  recovered_A_mean = recovered_A_mean,
  recovered_tau_mean = recovered_tau_mean,
  recovered_gamma_mean = recovered_gamma_mean,
  recovered_C_mean = recovered_C_mean
)
# Correlation matrix of all true and recovered parameters
cor(param_df)
                          true_A     true_tau   true_gamma      true_C
true_A                1.00000000  0.526142169  0.030751149 -0.29241480
true_tau              0.52614217  1.000000000  0.009123349 -0.04609718
true_gamma            0.03075115  0.009123349  1.000000000 -0.05625588
true_C               -0.29241480 -0.046097176 -0.056255879  1.00000000
recovered_A_mean      0.83048862  0.592539600  0.029477191 -0.32226767
recovered_tau_mean    0.60364748  0.861302226 -0.047747094  0.05660818
recovered_gamma_mean -0.02950892 -0.265659097  0.214512844 -0.22486455
recovered_C_mean     -0.34854897  0.001779931 -0.006761706  0.85413957
                     recovered_A_mean recovered_tau_mean recovered_gamma_mean
true_A                     0.83048862         0.60364748          -0.02950892
true_tau                   0.59253960         0.86130223          -0.26565910
true_gamma                 0.02947719        -0.04774709           0.21451284
true_C                    -0.32226767         0.05660818          -0.22486455
recovered_A_mean           1.00000000         0.66336373          -0.13007293
recovered_tau_mean         0.66336373         1.00000000          -0.28517217
recovered_gamma_mean      -0.13007293        -0.28517217           1.00000000
recovered_C_mean          -0.41415885         0.02270963          -0.30485496
                     recovered_C_mean
true_A                   -0.348548970
true_tau                  0.001779931
true_gamma               -0.006761706
true_C                    0.854139567
recovered_A_mean         -0.414158848
recovered_tau_mean        0.022709626
recovered_gamma_mean     -0.304854955
recovered_C_mean          1.000000000
# Scatter-plot matrix of true vs. recovered parameters with histograms on
# the diagonal (pch = 19: filled circles)
chart.Correlation(
  param_df,
  histogram = TRUE,
  pch = 19
)
Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter

Warning in par(usr): argument 1 does not name a graphical parameter